{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div style=\"background-color: #EEEFFF; padding: 10px;\">\n",
    "\\begin{align*}\n",
    "    \\Large{(2025)}\n",
    "\\end{align*}\n",
    "\n",
    "\\begin{align*}\n",
    "    \\Large{Information\\hspace{0.3cm} Fusion}\n",
    " \\end{align*}\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\\begin{align*}\n",
    "    \\Large{Fusion-Enhanced\\hspace{0.3cm} Multi-Label  \\hspace{0.3cm}Feature \\hspace{0.3cm} Selection  \\hspace{0.3cm} With  \\hspace{0.3cm} Sparse  \\hspace{0.3cm}Supplementation}\n",
    "\\end{align*}\n",
    "\n",
    "\\begin{align*}\n",
    "    \\Large{\\min_W\\lVert \\textbf{XW} - \\textbf{Y} \\rVert_F^2} + \\alpha \\mathrm{Tr}(\\textbf{WLW}^\\top) + \\beta\\lVert\\textbf{W}\\rVert_{2,1} + \\gamma(\\lVert \\textbf{WW}^\\top\\rVert_1 - \\lVert \\textbf{W} \\rVert^2_F)  \\quad\\quad {\\text{s.t.}\\quad} \\textbf{W}\\geq 0\n",
    "\\end{align*}\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "$ $"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import json\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "from pathlib import Path\n",
    "from scipy.io import loadmat\n",
    "from sklearn.cluster import KMeans\n",
    "from skmultilearn.adapt import MLkNN\n",
    "from sklearn.metrics import f1_score\n",
    "from sklearn.metrics import precision_score\n",
    "from sklearn.metrics import label_ranking_loss\n",
    "from sklearn.metrics import hamming_loss\n",
    "from sklearn.metrics import average_precision_score\n",
    "from sklearn.metrics import coverage_error\n",
    "from sklearn.metrics import zero_one_loss\n",
    "import sklearn\n",
    "from numpy.matlib import repmat\n",
    "import pandas as pd\n",
    "import warnings\n",
    "from sklearn.svm import LinearSVC\n",
    "from sklearn.multiclass import OneVsRestClassifier\n",
    "from sklearn.svm import SVC\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU\n"
     ]
    }
   ],
   "source": [
    "# Device selection: flip GPU to True to run on CUDA instead of the CPU.\n",
    "GPU = False\n",
    "\n",
    "if not GPU:\n",
    "    # CPU path: 32-bit float CPU tensors.\n",
    "    dtype = torch.FloatTensor\n",
    "    print(\"CPU\")\n",
    "else:\n",
    "    # CUDA path: switch cuDNN and its benchmark mode on, set\n",
    "    # CUDA_VISIBLE_DEVICES to '0', report the device count and use CUDA float tensors.\n",
    "    torch.backends.cudnn.enabled = True\n",
    "    torch.backends.cudnn.benchmark = True\n",
    "    os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "    print(\"num GPUs\", torch.cuda.device_count())\n",
    "    dtype = torch.cuda.FloatTensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def corr(Y):\n",
    "    \"\"\"Return the cosine-similarity matrix of the columns of Y and its degree matrix.\n",
    "\n",
    "    Y is transposed before the similarity call, so the similarities are taken\n",
    "    between the columns of Y. The second return value is a diagonal matrix whose\n",
    "    entries are the row sums of the similarity matrix.\n",
    "    \"\"\"\n",
    "    similarity = sklearn.metrics.pairwise.cosine_similarity(Y.T)\n",
    "    degree = np.diag(similarity.sum(axis=1))\n",
    "    return similarity, degree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " MLkNN Classifier:\n",
      "\n",
      " Micro-F1: 0.3126222355950053 \n",
      " Macro-F1: 0.15221819996401031 \n",
      " Average Precision: 0.33971767546579673 \n",
      " Hamming Loss: 0.058576923076923075 \n",
      " Ranking Loss: 0.219754663520228 \n",
      " Zero-One Loss: 0.8003333333333333 \n",
      " Coverage Error: 8.298333333333334\n",
      "\n",
      " SVM Classifier:\n",
      "\n",
      " Micro-F1: 0.3198599618077657 \n",
      " Macro-F1: 0.14176939705511035 \n",
      " Hamming Loss: 0.054794871794871795 \n",
      " Zero-One Loss: 0.8113333333333334\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAEMCAYAAAAs8rYIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAkRklEQVR4nO3de5xcdX3/8dd79jq535awZBd2A0kUqAWJSEWtogjyqwXUn2Jb0V5+UYsW+vPSotZSlWpbLy1t5ScKCv2hlJ8gUAUUFKwXLoYQhRAJQRLJPSEh2SR7n8/vj3M2mSyT7Gyyu2d35v18POYxZ77ne2Y+X4a85+z3nDmjiMDMzKpLLusCzMxs7Dn8zcyqkMPfzKwKOfzNzKqQw9/MrAo5/M3MqpDD38ysCjn8bdRImiTpMkk/lrRdUq+kzZLulPRuSbVZ12hWrfyPz0aFpBOA7wILgXuBzwDbgKOA1wNfA04EPpJVjWbVTP6Gr400SXngUeB44O0RcWuJPi8DXhYRXxrr+kaLpKkR0ZF1HWbl8LSPjYY/AxYBny8V/AAR8fPBwS/pAkk/lbQ7vf1U0vmDt5W0RtL9kl4k6buSOiTtlPQtSUcX9XufpJD0+yWeIydpnaTlg9oXS/q2pG2SuiU9Keljg6eo0tdfI2l++rrbgV1F639X0gOSOiVtkvQvkk5K67li0HMprfURSXvT8dwn6bWD+rUNbC/p9yT9XFKXpI2S/qnUNJqkEyR9LR1rj6QNkm6XdNrhjNsqh8PfRsNb0/tryt1A0p8D3wZmAZ8GPpUu3yZpSYlN5gH3A78BPgx8A3gzcENRn5uAbuDiEtu/Ln2O64tqOA/4KclU1eeBvwAeAD4JfLPEc0wBfgT0AR8Drkif55XA94F24LMkU16Li19rkP8A/g1YTTINdgUwHbin1AcXcB5wHXAX8JfAL4APMWgKTdJi4BHg7ST/bT8A/CvQALziCMZtlSAifPNtRG/Ac8CuYfSfCewmCb9pRe3TgKeBDmBGUfsaIIC3DXqef0/bX1TU9v+ALmDmoL7/AfQCc9PHjcAm4L+B2kF9/zJ93tcUtd2ftn26xHgeTl9zflFbHUnABnBFUfuFaduSQc9RCywFnmH/9Gxb2ncP0FbUV8DjwMYSbV3AS0rUmDuccftWOTfv+dtomEbRFEgZzgYmA1dFxL7t0uV/JdnDfv2gbTZExM2D2n6Y3p9Q1HY9yZ7u2wcaJE0hCd27I2JzUQ1zSQ5Ez5A0Z+AG3Jn2eUOJ2j9X/EDSXOBlwO0R8euisfQC/1Ji+z8i+XC7bdBrzgD+iyTwFwza5raIWFP03AHcBxydjg3gFOAk4GsR8cvBLxoRhSMct01wns+z0bALmDqM/u3p/YoS6x5P7+cPav/14I4kf3EAzC5quxvYQjL183/StreQfNgUT8O8OL2/7hB1zh30eGtEPD+obWAsT5bYvlTbi0n+W20usa74dVcVPR5q7LvZ/4Hx6CGed+D1YXjjtgrg8LfR8Djwaknzi/d+D0GH8Rr95TxfRPRJ+gZwmaQTImI1yQfBDpI968HbfBhYfpDn3TDo8d5DvXaZBGwF/uAQfR4f9LicsQ/cD3U63+GM2yqAw99Gwy3Aq0nO+vloGf2fTu9PAn4waN2J6X05HyIHcz1wGXCxpGuA1wDXRER3UZ+n0vs9EXHvEbzWQJ2LSqwr1fYUyYHWByNi9xG87mADf2WcOkS/kRq3TTCe87fR8FWS8PlQqVM1ASSdlp7hA3APyUHMD0iaWtRnKskZKrvTPoclIpYDvySZX7+Y5P/7wWfefI9keuivJc0qUW++uLZDvNZmkgO150vaN1UlqQ64tMQmN6T1fKbU86XHEA7HL0im0f5E0kklnndgj39Exm0Tj/f8bcRFxF5Jv0fyDd/bJH2fJLyfA5qA1wLnAP+Y9n9e0kdIztZ5SNLX06d6N8nB2/dExM4jLOt6ktMY/wpYFREPDqp5
j6SLgduAJyVdR3L20QzgRSSnkV5IcpbPUD5EMt6fSfoSsBN4G1A/8HJFr/stSV8D3i/ppcB3SL4J3QL8Dsn4Bx/vGFJEhKQ/JvlL6mFJ15JMH80AfpfkWMi/jvC4bSLJ+nQj3yr3BkwiOV3wJyRz7L0kBza/C7wTqBnU/0LgZyR/BexJly8o8bxrgPtLtL+GJFjfXWLd3PT1A/jYIWo+Gfi/wHqgJ633Z8DfALOK+t0PrDnE85wFPEhyquVmkjN9Xp6+/kdK9H8n8GOSg+Vd6RhvJfmG9ECfNgadKlq07op0Xdug9kXpeDal49lAEvQvPZxx+1Y5N1/ewWyMSHoL8C3gHRFxU9b1WHXznL/ZCEsv19A4qK0O+N8k3wa+P4u6zIp5zt9s5DUAayXdSHLgezbJl8xeAvxDRGzKsjgzcPibjYZekuMa5wPNJOfSPwlcEhV0FVOb2Dznb2ZWhcb9nv+cOXOira0t6zLMzCaURx55ZFtENB1s/bgP/7a2NpYuXZp1GWZmE4qktYda77N9zMyqkMPfzKwKOfzNzKqQw9/MrAo5/M3MqpDD38ysCjn8zcyqUMWG/9Zb7mLHvT/Nugwzs3GpYsN/+x0/YPtd92VdhpnZuFSx4d/QejTdz27Mugwzs3GpgsP/GHo2biH6+rIuxcxs3Kng8G+G/gLdG7ZkXYqZ2bhT2eEPnvoxMyuhcsO/JQ3/dQ5/M7PBKjb8a2ZMo2bqZHq8529m9gIVG/6SaGhppuvZDVmXYmY27lRs+EMy7+85fzOzF6ro8K9vaaZ38zYK3T1Zl2JmNq5UdPg3tDZDBD0bNmddipnZuFLR4d/Yegzg0z3NzAar6PCv33euvw/6mpkVq+jwr506mZoZ07znb2Y2SEWHPyRf9vIXvczMDlT54e/TPc3MXqAqwr9363b6O7uyLsXMbNyoivAH6Fm3KeNKzMzGj8oP/xZf3dPMbLDqCX8f9DUz26fiw79mcp7aOTO9529mVqTiwx/S0z39RS8zs32GDH9JrZLuk7RS0gpJl6bt/ylpeXpbI2l50TaXS1ot6UlJ5xS1nybpsXTdVZI0KqMaxKd7mpkdqLaMPn3AByNimaSpwCOS7omItw90kPR5YGe6fCJwEXAScAxwr6SFEdEPXA0sAR4E7gTOBe4ayQGV0tDSTN+OnfTv2UvN5Emj/XJmZuPekHv+EbExIpalyx3ASmDewPp07/1twDfTpvOBmyKiOyKeAVYDp0tqBqZFxAMREcANwAUjOZiD8e/5mpkdaFhz/pLagFOBh4qaXwVsjoin0sfzgGeL1q9L2+aly4PbS73OEklLJS3dunXrcEosqcFX9zQzO0DZ4S9pCnALcFlE7Cpa9Q727/UDlJrHj0O0v7Ax4pqIWBwRi5uamsot8aAaWuYCvrqnmdmAcub8kVRHEvw3RsStRe21wJuB04q6rwNaix63ABvS9pYS7aMu19BA3dw53vM3M0uVc7aPgGuBlRHxhUGrXw/8KiKKp3PuAC6S1CCpHVgAPBwRG4EOSWekz3kxcPuIjKIMvrqnmdl+5Uz7nAm8Ezir6NTO89J1F3HglA8RsQK4GXgCuBu4JD3TB+B9wFdJDgI/zRic6TPAp3uame035LRPRPyE0vP1RMS7D9J+JXBlifalwMnDK3FkNLQ2079rN307O6idPjWLEszMxo2q+IYv+HRPM7Ni1RP+Lf49XzOzAVUT/vXz5kIu54O+ZmZUUfjn6uqoP7rJ0z5mZlRR+IPP+DEzG1CV4Z9cWsjMrHpVV/i3NFPY20nfjp1Zl2JmlqnqCn+f7mlmBlRd+PvqnmZmUGXhX390E9TU+Fx/M6t6VRX+qq2hYd5c7/mbWdWrqvAHX93TzAyqMfxbm+lZt4koFLIuxcwsM1UZ/oWubnq37ci6FDOzzFRl+IMv8GZm1a36wr/F5/qbmVVd+NfNnYPq63zQ18yqWtWFv3I5GuYd7T1/M6tqVRf+4Kt7mplVbfj3rN9E9PcP
3dnMrAJVZ/i3NBO9ffRseS7rUszMMlGV4V+fnu7Z46kfM6tSVRn+jb66p5lVuaoM/9o5M8nlG+nyF73MrEpVZfhLor7laE/7mFnVqsrwB1/d08yqW/WGf2sz3Ru2EH19WZdiZjbmqjr86e+nZ+OWrEsxMxtzQ4a/pFZJ90laKWmFpEuL1n1A0pNp+z8WtV8uaXW67pyi9tMkPZauu0qSRn5I5fHv+ZpZNasto08f8MGIWCZpKvCIpHuAucD5wEsiolvSUQCSTgQuAk4CjgHulbQwIvqBq4ElwIPAncC5wF0jPahy+OqeZlbNhtzzj4iNEbEsXe4AVgLzgPcBn42I7nTdwPzJ+cBNEdEdEc8Aq4HTJTUD0yLigYgI4AbggpEeULlqZ04jN3mSD/qaWVUa1py/pDbgVOAhYCHwKkkPSfqRpJel3eYBzxZtti5tm5cuD27PhCRf4M3MqlbZ4S9pCnALcFlE7CKZMpoJnAF8GLg5ncMvNY8fh2gv9VpLJC2VtHTr1q3lljhsDn8zq1Zlhb+kOpLgvzEibk2b1wG3RuJhoADMSdtbizZvATak7S0l2l8gIq6JiMURsbipqWk44xmWhtZmejZvo9DTO2qvYWY2HpVzto+Aa4GVEfGFolW3AWelfRYC9cA24A7gIkkNktqBBcDDEbER6JB0RvqcFwO3j+RghquhpRkKBXo2bM6yDDOzMVfO2T5nAu8EHpO0PG37KHAdcJ2kx4Ee4F3pgdwVkm4GniA5U+iS9EwfSA4Sfx3Ik5zlk8mZPgP2/5j7RhrbWobobWZWOYYM/4j4CaXn6wH+6CDbXAlcWaJ9KXDycAocTcXhb2ZWTar2G74AtdOmUjN9Kt2+uqeZVZmqDn9IL/DmPX8zqzIO/1Zf3dPMqo/Dv7WZ3i3PUejqzroUM7Mx4/A/NrnAW9fa9RlXYmY2dqo+/PMntAHQ+dSaTOswMxtLVR/+Da3N5PKNdK56JutSzMzGTNWHv3I58iccR+dTDn8zqx5VH/4A+UXz6XzqGaJQyLoUM7Mx4fAH8gvbKeztonvdpqxLMTMbEw5/IL+gHcDz/mZWNRz+QOP8VlRbS+eqX2ddipnZmHD4A7m6OhrbW73nb2ZVw+Gfyi9sp3PVMyRXpTYzq2wO/1R+UTt9z++id9v2rEsxMxt1Dv9UfuF8ADqf9Ly/mVU+h38qf8JxIHne38yqgsM/VTMpT0Nrs8PfzKqCw79IfkE7ex3+ZlYFHP5F8gvb6d20lb5dHVmXYmY2qhz+RSYtSg/6rlqTbSFmZqPM4V8kv3DgMg8+48fMKpvDv0jtjGnUHTXbB33NrOI5/AfJL/RBXzOrfA7/QfIL2+n+zQb/oLuZVTSH/yD5Be1QKNC5em3WpZiZjRqH/yD7z/jxQV8zq1wO/0Hq5s6hZtoU9j7peX8zq1wO/0EkJZd39g+6m1kFGzL8JbVKuk/SSkkrJF2atl8hab2k5entvKJtLpe0WtKTks4paj9N0mPpuqskaXSGdWTyC9vpenot0deXdSlmZqOinD3/PuCDEfFi4AzgEkknpuu+GBGnpLc7AdJ1FwEnAecCX5JUk/a/GlgCLEhv547cUEbOpIXzid4+utasy7oUM7NRMWT4R8TGiFiWLncAK4F5h9jkfOCmiOiOiGeA1cDpkpqBaRHxQCQ/l3UDcMGRDmA07P+mr6d+zKwyDWvOX1IbcCrwUNr0fkm/lHSdpJlp2zzg2aLN1qVt89Llwe2lXmeJpKWSlm7dunU4JY6IhtZmco0N/rKXmVWsssNf0hTgFuCyiNhFMoVzPHAKsBH4/EDXEpvHIdpf2BhxTUQsjojFTU1N5ZY4YlRTQ+MJbf5VLzOrWGWFv6Q6kuC/MSJuBYiIzRHRHxEF4CvA6Wn3dUBr0eYtwIa0vaVE+7g0aWE7nU+tIQqFrEsxMxtx5ZztI+BaYGVEfKGovbmo24XA4+nyHcBFkhoktZMc2H04IjYCHZLOSJ/zYuD2ERrHiMsv
bKewt5OeDZuzLsXMbMTVltHnTOCdwGOSlqdtHwXeIekUkqmbNcB7ACJihaSbgSdIzhS6JCL60+3eB3wdyAN3pbdxqfigb0NL8xC9zcwmliHDPyJ+Qun5+jsPsc2VwJUl2pcCJw+nwKw0zj8WamrYu+oZZpz1iqzLMTMbUf6G70Hk6utonN/q0z3NrCI5/A9h0oJ2Op/8NcnXEszMKofD/xDyC9vp27GTvm07si7FzGxEOfwPYeCgr7/sZWaVxuF/CPkFbYAv82Bmlcfhfwg1kyfR0NpM51P+pq+ZVRaH/xDyC9u9529mFcfhP4T8gnZ6Nmyhb9furEsxMxsxDv8h7Pumr3/Zy8wqiMN/CPmFAz/o7vA3s8rh8B9C3azp1M2Z5fA3s4ri8C9DfpEP+ppZZXH4lyG/sJ2utespdHVnXYqZ2Yhw+Jchv3A+FAp0Pr0261LMzEaEw78M/kF3M6s0Dv8y1B/dRM3UKQ5/M6sYDv8ySCK/oM3hb2YVw+Ffpvyi+XQ+/Ruir3/ozmZm45zDv0z5he1ETw9da9dnXYqZ2RFz+Jdp0r6Dvr7Cp5lNfA7/MjUcewxqqGfvkw5/M5v4HP5lUk0NU377xez80UOe9zezCc/hPwyzLzyH3s3b2PnTpVmXYmZ2RBz+wzD9zMXUzZ3DtlvuyroUM7Mj4vAfBtXWMOfCc9i99DG61qzLuhwzs8Pm8B+m2W96HaqrZdutd2ddipnZYXP4D1PtzOnMeN2ZbL/zfvr3dGZdjpnZYXH4H4Y5b30jhb2d7Lj7R1mXYmZ2WBz+h2HyiQvIv+h4tt16FxGRdTlmZsM2ZPhLapV0n6SVklZIunTQ+g9JCklzitoul7Ra0pOSzilqP03SY+m6qyRpZIczdpre+ka6nlnH7mWPZ12KmdmwlbPn3wd8MCJeDJwBXCLpREg+GICzgd8MdE7XXQScBJwLfElSTbr6amAJsCC9nTtC4xhzM153JjXTp/q0TzObkIYM/4jYGBHL0uUOYCUwL139ReAjQPHcx/nATRHRHRHPAKuB0yU1A9Mi4oFI5kpuAC4YsZGMsVxDPbPf9Dp2/vjn9GzelnU5ZmbDMqw5f0ltwKnAQ5J+H1gfEb8Y1G0e8GzR43Vp27x0eXB7qddZImmppKVbt24dToljavaF50AheO6272ddipnZsJQd/pKmALcAl5FMBX0M+ESpriXa4hDtL2yMuCYiFkfE4qampnJLHHMNzUcx7czFPHfHvRR6erMux8ysbGWFv6Q6kuC/MSJuBY4H2oFfSFoDtADLJB1NskffWrR5C7AhbW8p0T6hzXnLufTt2Mnz9z2QdSlmZmUr52wfAdcCKyPiCwAR8VhEHBURbRHRRhLsL42ITcAdwEWSGiS1kxzYfTgiNgIdks5In/Ni4PbRGdbYmfqyl9Bw7DFs+5YP/JrZxFHOnv+ZwDuBsyQtT2/nHaxzRKwAbgaeAO4GLomIgWsgvw/4KslB4KeBCZ+YyuWYc+E57F2xir2/ejrrcszMyqLx/iWlxYsXx9Kl4/sSyn0de3jigiXMOOsVHPuxS7Iux8wMSY9ExOKDrfc3fEdA7dTJzDzn1ey45yf07ezIuhwzsyE5/EfInLecS/T0sP07P8i6FDOzITn8R0j++OOYfOpJbLv1e0S/f+bRzMY3h/8IanrLufRs3MKuBx/NuhQzs0Ny+I+g6a8+nbo5s3zap5mNew7/EaTaWmZfcDYdDy2n+9kJ//01M6tgDv8RNvv8s1FtLdtu/V7WpZiZHZTDf4TVzZ7J9Ne8nOe++0P6OvZkXY6ZWUkO/1Fw1B9eQKGzm2c/8+/+pS8zG5cc/qNg0qL5HPO+P2Ln/Q+x7ebvZl2OmdkLOPxHSdM73sS0V72M9f92A3seX5V1OWZmB3D4jxJJHPvxD1A/dzZr/ubzvuyDmY0rDv9RVDt1Mm2f/hB9259n7Sf/hSgUsi7JzAxw+I+6SS86nnmX/jEdDzzKlv/4dtblmJkBDv8x
MfvCc5hx9ivZ+JWb6Fj2eNblmJk5/MeCJFo/8l4aWo5m7Se+SO9zO7IuycyqnMN/jNRMztN25Yfp37OXtX/7RV/508wy5fAfQ/njj6X1w0vYvWwFm776n1mXY2ZVzOE/xmad91pmvel1bL7+FnY9sCzrcsysSjn8M9Dyv/+UxhOOY+0nr6Jn87asyzGzKuTwz0CuoYG2T3+I6O1jzcc/T6G3N+uSzKzKOPwz0njsMRz70T9n74pVrPuHLxN9PgBsZmPH4Z+hGWe9grl/8ja233kfv/7IZ+jf05l1SWZWJRz+GWv+s7fT+lfvpePnv2D1n/8NPVufy7okM6sCDv9xYPb5ZzP/nz5K9/qNPPW/Lqdz9ZqsSzKzCufwHyemnXEqC66+EiJ46r0fZ9dDy7MuycwqmMN/HMkvaGPBVz5L/TFH8esP/T3PfecHWZdkZhXK4T/O1B81mwVXf5qpp53Ms3//JTZ++Rv+KUgzG3FDhr+kVkn3SVopaYWkS9P2T0n6paTlkr4v6ZiibS6XtFrSk5LOKWo/TdJj6bqrJGl0hjWx1UyexPzPfZRZb3o9m6+/hd/83VUUevxdADMbOeXs+fcBH4yIFwNnAJdIOhH4p4h4SUScAnwH+ARAuu4i4CTgXOBLkmrS57oaWAIsSG/njuBYKopqa2n96/fS/J4/YMf3/5un//JT9O3anXVZZlYhhgz/iNgYEcvS5Q5gJTAvInYVdZsMDMxNnA/cFBHdEfEMsBo4XVIzMC0iHohkHuMG4IKRG0rlkcTcd72F4664jL2PP8lTSy737wGb2YgY1py/pDbgVOCh9PGVkp4F/pB0zx+YBzxbtNm6tG1eujy4vdTrLJG0VNLSrVu3DqfEijTzDa/i+H/+BIW9XTy15HKe/ezV/k1gMzsiZYe/pCnALcBlA3v9EfGxiGgFbgTeP9C1xOZxiPYXNkZcExGLI2JxU1NTuSVWtCmnnsSLvnkVTRe9iee++0NWXvQBnvuve/27wGZ2WMoKf0l1JMF/Y0TcWqLLN4C3pMvrgNaidS3AhrS9pUS7lalmcp55f/FuFn39czS2tfDsZ67mqfd+nL2rnsm6NDObYMo520fAtcDKiPhCUfuCom6/D/wqXb4DuEhSg6R2kgO7D0fERqBD0hnpc14M3D5C46gq+eOP44QvfYpjP/5+etZtZNWffIR1X7yW/t17si7NzCaI2jL6nAm8E3hM0vK07aPAn0paBBSAtcB7ASJihaSbgSdIzhS6JCIGLln5PuDrQB64K73ZYZDErPNey7RXvoyNX/4G2751F8//8GfM+8C7mXH2K/FZtGZ2KBrvXyBavHhxLF26NOsyxr29T6zm2c9dQ+evnmbKS0/mmA+8i0mL5mddlpllRNIjEbH4oOsd/pUj+vt57vZ72Pjlb9DfsYfJL3kRc958LtNfewa5urqsyzOzMeTwr0J9u3az/c772HbL3fSs30TtrBnMPv/1zL7gDdQ3zc66PDMbAw7/KhaFAh0P/4Jt37or+bH4nJj+6pfT9NY3MvmUE31cwKyCDRX+5RzwtQlKuRzTzjiVaWecSvf6TWz79vfY/l8/ZOd9D9DY3sqct76Rmee8mppJ+axLNbMx5j3/KlPo6mbHvT9h27fuonPVM+TyjUxZ/FtM+52XMu3lp1DffFTWJZrZCPCevx0g19jA7N97HbP+x1nsfXwV2+++n10PPMquH/8cgIa2ln1/LUz+7ReTa6jPuGIzGw3e8zcigu6169n14KN0PPgou5c/QfT0kmtsYMpLT2bq75zKtJefSkPL0VmXamZl8p6/DUkSjW0tNLa1cNRFb6K/s4vdy1bQ8dCjyV8FP3uE9UDdnFnkF7aTX9ROfuF88gvbqT+6yQeOzSYgh7+9QE2+kelnnsb0M08DoHvdRnY9+Ch7VzxF56pn2PXgo5BeUK5m6hTyC9rIL0o+DCYtbKfh2GNQTc2hXsLMMubwtyE1tDTT9NZmeGvyuNDVTefTa+lc
9cy+27Zb7iLSXxtTfR31x8ylYd7c9P5o6uel981H+TiC2Tjg8LdhyzU2MPmkhUw+aeG+tujro2vt+uTD4Om19KzfTPf6TexetoJCZ9cB29c1zaJ+3tE0HDOXuqPnUDdnFnVzZlI3eyZ1c2ZRO3M6qvVfDmajyeFvI0K1teSPP4788ccd0B4R9O3YRc/6TXRv2Jzcr99Mz4bN7Hp4OX3PPQ+DTzrI5aidOe2AD4XaOTOpnTGN2mlTqZ0xjZrpyX3tjKnkGhrGbqBmFcLhb6NKEnWzplM3azqTf2vRC9ZHXx+923fSt20Hvdu2J7fndtA78HjLdvY+sZq+HTsP+hq5xobkw2D6VGqnT6Nm2hRqpk6mZvKk5H7KJGqmTE6X9z/OTZlErrHBB6ytKjn8LVOqraX+qNnUH3Xoaw5FXz/9Hbvp29lB3/O76NvZQX96n9x20f98ct+zeSv9u/fS37Gb6O07dAG5HLl8IzWTGslNyqfLeXKT8ul9I7lJjdTk8+Qa68k1NqLGemryjaihgVy+gVzjoFtDfXJco6bGHyw2bjn8bUJQbQ21M6dTO3P6sLYrdPfQv3sP/R17kg+EfcvJ48LeTvr3dlLY20Whs5P+vV0U9nbSs2lL0rY3aYuenuEXncuRa6hHDfXk6uuKlve3qa6OXH0tqqtD9fVFy3X71qu+jlxtLaqvRbXp+rpkOVeiTbU16S1drqtFNUXtPhPLcPhbhRvYC6+bPfOInif6+yl091Do6k5und0UurspdHZR6Oqh0NW1b11091Lo7iF6epJtunuInt6k/77lHvo79lDo7SV6ktu+5d4+Cj290N8/dGGHI5fb9yGg2prkL5SBx0XLvKAtl/StSZdzuf3ra2qgNofStv3rBvol66itSe5zA333LzPQZ6D/wPMU9Sen/X2lfX3RwL32Pd7Xt6hNuXSbnA7skxt4Ph34XAf00/7nGHjtCfyXncPfrAyqqaEmnQoaK9Hfv++DIPr6iN701le6LXp6ib7+ZLmvb/9yb1/yXMWP+/r3t/X3Q9Hygfd90F8g+gtJv95uCunyvu0G1vf1QaFAFJLHDGyzr60f+gtj9t9vzOSKPyiULufSDwglHxDFHxZpv4EPr33bogM+sJBY9PXPkasfnd/icPibjVMDe9S5xso5myki9n0YMPAhMfC4P/2QKAz+4IhB94VB94Pbi/pH+hoRUIj9fYrrKF4XA33igHXEwPMd2B5Roi+xv4aitgNfo6gtfY792+7vk3wojA6Hv5mNGUn7po4sW7msCzAzs7Hn8Dczq0IOfzOzKuTwNzOrQg5/M7Mq5PA3M6tCDn8zsyrk8Dczq0Lj/gfcJW0F1h7m5nOAbSNYTtYqbTxQeWOqtPFA5Y2p0sYDpcd0XEQ0HWyDcR/+R0LS0kP9ev1EU2njgcobU6WNBypvTJU2Hji8MXnax8ysCjn8zcyqUKWH/zVZFzDCKm08UHljqrTxQOWNqdLGA4cxpoqe8zczs9Iqfc/fzMxKcPibmVWhigx/SedKelLSakl/nXU9I0HSGkmPSVouaWnW9RwOSddJ2iLp8aK2WZLukfRUen9kP7Y7hg4yniskrU/fp+WSzsuyxuGQ1CrpPkkrJa2QdGnaPpHfo4ONaUK+T5IaJT0s6RfpeP4ubR/2e1Rxc/6SaoBVwNnAOuDnwDsi4olMCztCktYAiyNiwn45RdKrgd3ADRFxctr2j8D2iPhs+kE9MyL+Kss6y3WQ8VwB7I6Iz2VZ2+GQ1Aw0R8QySVOBR4ALgHczcd+jg43pbUzA90nJL8ZPjojdkuqAnwCXAm9mmO9RJe75nw6sjohfR0QPcBNwfsY1GRAR/w1sH9R8PnB9unw9yT/MCeEg45mwImJjRCxLlzuAlcA8JvZ7dLAxTUiR2J0+rEtvwWG8R5UY/vOAZ4ser2MCv9lFAvi+pEckLcm6mBE0NyI2QvIPFTgq43pGwvsl/TKdFpowUyTFJLUBpwIPUSHv0aAxwQR9
nyTVSFoObAHuiYjDeo8qMfxL/dx9JcxtnRkRLwXeCFySTjnY+HM1cDxwCrAR+Hym1RwGSVOAW4DLImJX1vWMhBJjmrDvU0T0R8QpQAtwuqSTD+d5KjH81wGtRY9bgA0Z1TJiImJDer8F+DbJ9FYl2JzOyw7Mz27JuJ4jEhGb03+cBeArTLD3KZ1HvgW4MSJuTZsn9HtUakwT/X0CiIjngfuBczmM96gSw//nwAJJ7ZLqgYuAOzKu6YhImpwerELSZOANwOOH3mrCuAN4V7r8LuD2DGs5YgP/AFMXMoHep/Rg4rXAyoj4QtGqCfseHWxME/V9ktQkaUa6nAdeD/yKw3iPKu5sH4D0tK1/BmqA6yLiymwrOjKS5pPs7QPUAt+YiGOS9E3gNSSXn90M/C1wG3AzcCzwG+B/RsSEOIh6kPG8hmQqIYA1wHsG5mLHO0mvBH4MPAYU0uaPksyRT9T36GBjegcT8H2S9BKSA7o1JDvvN0fEJyXNZpjvUUWGv5mZHVolTvuYmdkQHP5mZlXI4W9mVoUc/mZmVcjhb2ZWhRz+ZmZVyOFvZlaF/j+ExpkRRELD/wAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Multi-label benchmark datasets; each one is loaded from ../../datasets/ by name.\n",
    "datasetML = ['Arts','Science','Entertainment','Health','Recreation' ,'Reference','Social','Society','Corel5k','Business','Education']\n",
    "\n",
    "# Seed the torch RNG so the random initialisation of W is repeatable.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "for db in range(0,1):\n",
    "\n",
    "    dataset_name = Path(f'{datasetML[db]}').stem\n",
    "    data = loadmat(f'../../datasets/{datasetML[db]}')\n",
    "    train = data['train']\n",
    "    test = data['test']\n",
    "\n",
    "    # NOTE(review): the evaluation split (X_test, Y_test) is read from data['train']\n",
    "    # and the fitting split (X, Y) from data['test']. Kept exactly as found;\n",
    "    # confirm this against the layout of the .mat files.\n",
    "    X_test = train[0][0].T\n",
    "    Y_test = train[0][1].T\n",
    "    Y_test[Y_test == -1] = 0\n",
    "\n",
    "    X = torch.from_numpy(test[0][0].T).type(dtype)\n",
    "\n",
    "    Y = torch.from_numpy(test[0][1].T).type(dtype)\n",
    "    Y[Y == -1] = 0\n",
    "\n",
    "    Ycn = Y.cpu().numpy()\n",
    "    Xc = test[0][0].T\n",
    "\n",
    "    # Feature count d and label count l (n is the number of rows).\n",
    "    n,d = X.shape\n",
    "    n,l = Y.shape\n",
    "\n",
    "    # Small positive constant guarding the divisions in the update rule.\n",
    "    epsilon = torch.tensor(torch.finfo(torch.float32).eps).type(dtype)\n",
    "\n",
    "    # Label correlation (cosine similarity): S is the similarity matrix, A the\n",
    "    # diagonal matrix of its row sums and L = A - S the resulting graph Laplacian.\n",
    "    # corr() returns NumPy arrays, so they are converted to tensors of the same\n",
    "    # type as X and Y instead of relying on implicit NumPy/torch mixing.\n",
    "    S,A = corr(Ycn)\n",
    "    S = torch.from_numpy(S).type(dtype)\n",
    "    A = torch.from_numpy(A).type(dtype)\n",
    "    L = A-S\n",
    "\n",
    "    # Quantities that do not depend on W: computed once instead of every iteration.\n",
    "    XtY = X.T @ Y\n",
    "    XtX = X.T @ X\n",
    "    dd = torch.ones((d,d)).type(dtype)\n",
    "\n",
    "    # Grid for a full hyper-parameter search: [0.01,0.1,0.3,0.5,0.7,0.9,1]\n",
    "    # for each of alpha, beta and gamma. A single setting is used here.\n",
    "    alpha = [0.1]\n",
    "    beta  = [0.1]\n",
    "    gamma = [0.1]\n",
    "\n",
    "    # Number of multiplicative-update iterations.\n",
    "    t=30\n",
    "    for a in alpha:\n",
    "        for b in beta:\n",
    "            for g in gamma:\n",
    "                erro= torch.zeros(t)\n",
    "                for avg in range(1):\n",
    "                    W = torch.rand(d,l).type(dtype)\n",
    "                    for i in range(t):\n",
    "                        # Multiplicative update: W is scaled element-wise by Up / Dw,\n",
    "                        # with Dw clamped from below by epsilon.\n",
    "                        D  = torch.diag(1 / (2 * torch.norm(W, dim=1) + epsilon))\n",
    "                        Up = (XtY + a * (W @ S) + g * (W))\n",
    "                        Dw = (XtX @ W + a * ( W @ A) + b * (D @ W) + g * (dd @ W))\n",
    "                        W  = W * (Up / torch.maximum(Dw , epsilon))\n",
    "\n",
    "                        # Objective value, term by term as in the formula at the top of\n",
    "                        # the notebook. T2 uses the Laplacian L = A - S and T4 uses\n",
    "                        # W @ W.T, matching the A/S and dd @ W terms of the update above.\n",
    "                        T1 = torch.norm(X @ W - Y) ** 2\n",
    "                        T2 = a * (torch.trace(W @ L @ W.T))\n",
    "                        T3 = b  * (torch.sum(torch.norm(W, dim=1)))\n",
    "                        T4 = g * (torch.sum(torch.abs(W @ W.T)) - torch.norm(W)**2)\n",
    "                        erro[i] = T1 + T2 + T3 + T4\n",
    "\n",
    "                    plt.title('Convergence', fontsize=18)\n",
    "                    plt.xlabel('Iteration')\n",
    "                    plt.ylabel('Objective value')\n",
    "                    plt.plot(erro.numpy(),color=\"#cf3952\")\n",
    "\n",
    "                    # Rank features by the L2 norm of their row of W (ascending order)\n",
    "                    # and keep the last 20 percent, i.e. the rows with the largest norms.\n",
    "                    WW = torch.norm(W, dim=1, p=2)\n",
    "                    sQ = torch.argsort(WW)\n",
    "                    sQ = sQ.cpu()\n",
    "\n",
    "                    nosf = int (20 * d / 100)\n",
    "                    top = sQ[d-nosf:]\n",
    "                    sX = X[:,top].cpu().numpy()\n",
    "                    sX_test = X_test[:,top.numpy()]\n",
    "\n",
    "                    classifier = MLkNN(k=10)\n",
    "                    classifier.fit(sX, Ycn.astype(int))\n",
    "                    # predict\n",
    "                    predictions = classifier.predict(sX_test).toarray()\n",
    "                    scores = classifier.predict_proba(sX_test).toarray()\n",
    "                    MIC = f1_score(Y_test, predictions, average='micro')\n",
    "                    MAC = f1_score(Y_test, predictions, average='macro')\n",
    "                    # Inputs are transposed, so the default averaging runs over samples\n",
    "                    # rather than labels (presumably example-based average precision).\n",
    "                    AVP = average_precision_score(Y_test.T,scores.T)\n",
    "                    HML = hamming_loss(Y_test,predictions)\n",
    "                    RNL = label_ranking_loss(Y_test,scores)\n",
    "                    ZER = zero_one_loss(Y_test,predictions)\n",
    "                    CVE = coverage_error(Y_test,scores)\n",
    "                    print('\\n','MLkNN Classifier:')\n",
    "                    print('\\n','Micro-F1:',MIC,'\\n','Macro-F1:',MAC,'\\n','Average Precision:',AVP,'\\n','Hamming Loss:',HML,'\\n','Ranking Loss:',RNL,'\\n','Zero-One Loss:',ZER,'\\n','Coverage Error:',CVE)\n",
    "\n",
    "                    #SVM\n",
    "                    classifiervm = OneVsRestClassifier(LinearSVC())\n",
    "                    classifiervm.fit(sX, Ycn.astype(int))\n",
    "                    predictionsvm = classifiervm.predict(sX_test)\n",
    "                    MICvm = f1_score(Y_test, predictionsvm, average='micro')\n",
    "                    MACvm = f1_score(Y_test, predictionsvm, average='macro')\n",
    "                    HMLvm = hamming_loss(Y_test,predictionsvm)\n",
    "                    ZERvm = zero_one_loss(Y_test,predictionsvm)\n",
    "                    print('\\n','SVM Classifier:')\n",
    "                    print('\\n','Micro-F1:',MICvm,'\\n','Macro-F1:',MACvm,'\\n','Hamming Loss:',HMLvm,'\\n','Zero-One Loss:',ZERvm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
